# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
from scipy.signal import welch

from mne.time_frequency import psd_array_multitaper, psd_array_welch
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.psd import _median_biases
from mne.utils import catch_logging


def test_psd_nan():
    """Test that trailing NaNs do not alter psd_array_welch results.

    The PSD of a NaN-padded signal (2-d and 1-d input) must match the PSD of
    the finite prefix alone. Also checks the logged default configuration.
    """
    n_samples, n_fft, n_overlap = 2048, 1024, 512
    sfreq = float(n_fft)
    n_good = n_fft + n_overlap
    welch_kw = dict(n_fft=n_fft, n_overlap=n_overlap)
    x = np.random.RandomState(0).randn(1, n_samples)
    # reference: only the finite leading part of the signal
    psds, freqs = psd_array_welch(x[:, :n_good], sfreq, **welch_kw)
    # NaN tail mimics what Raw.get_data() will give us
    x[:, n_good:] = np.nan
    # 2-d input
    psds_nan, freqs_nan = psd_array_welch(x, sfreq, **welch_kw)
    assert_allclose(freqs, freqs_nan)
    assert_allclose(psds, psds_nan)
    # 1-d input
    psds_nan, freqs_nan = psd_array_welch(x[0], sfreq, **welch_kw)
    assert_allclose(freqs, freqs_nan)
    assert_allclose(psds[0], psds_nan)
    # default parameters are reported in the debug log
    with catch_logging() as log:
        psd_array_welch(x, sfreq, verbose="debug")
    log_text = log.getvalue()
    assert "using 256-point FFT on 256 samples with 0 overlap" in log_text
    assert "hamming window" in log_text


def test_bad_annot_handling():
    """Test that a NaN gap (as left by a bad annotation) leaves results unchanged."""
    n_per_seg, n_chan = 256, 3
    n_times = 5 * n_per_seg
    rng = np.random.default_rng(seed=42)
    x = rng.standard_normal(size=(n_chan, n_times))
    want_psds, want_freqs = psd_array_welch(x, sfreq=100)
    # Insert a one-sample NaN gap, splitting the array into unequal spans.
    # Cutting exactly at `n_per_seg` is idealized, but it means the segments
    # are unchanged, so the result should be ~identical to the gap-free one.
    gap = np.full((n_chan, 1), np.nan)
    x_gap = np.concatenate((x[..., :n_per_seg], gap, x[..., n_per_seg:]), axis=-1)
    got_psds, got_freqs = psd_array_welch(x_gap, sfreq=100)
    # frequencies must match exactly, powers to within floating-point noise
    assert_array_equal(got_freqs, want_freqs)
    assert_allclose(got_psds, want_psds, rtol=1e-15, atol=0)


def _make_psd_data():
    """Make noise data with sinusoids in 2 out of 7 channels.

    Returns
    -------
    data : ndarray, shape (7, 8000)
        Uniform noise scaled by 0.1. Channel ``i`` additionally carries a
        sinusoid of amplitude 2 at ``sinusoid_freqs[i]`` for ``i`` in 0, 1.
    sfreq : int
        Sampling frequency (1000).
    sinusoid_freqs : list of float
        The injected sinusoid frequencies, in channel order.
    """
    n_chan, n_times, sfreq = 7, 8000, 1000
    data = 0.1 * np.random.default_rng(0).random((n_chan, n_times))
    times = np.arange(n_times) / sfreq
    sinusoid_freqs = [8.0, 50.0]
    # the leading channels (one per frequency) receive a sinusoid
    for ch, freq in enumerate(sinusoid_freqs):
        data[ch] += 2 * np.sin(np.pi * 2.0 * freq * times)
    return data, sfreq, sinusoid_freqs


@pytest.mark.parametrize(
    "psd_func, psd_kwargs",
    [
        (psd_array_welch, dict(n_fft=128, window="hann")),
        (psd_array_multitaper, dict(low_bias=True)),
    ],
)
def test_psd(psd_func, psd_kwargs):
    """Test the Welch and multitaper PSD estimators on synthetic data.

    Checks the output shape, that frequencies and powers are non-negative, and
    that each sinusoid-bearing channel peaks within one bin of its sinusoid.
    """
    data, sfreq, sinusoid_freqs = _make_psd_data()
    # Copy rather than update in place: the parametrized dict is a module-level
    # object shared by every invocation of this test.
    kwargs = dict(psd_kwargs, fmin=2, fmax=70, verbose="debug")
    # compute PSD and test basic conformity
    with catch_logging() as log:
        psds, freqs = psd_func(data, sfreq, **kwargs)
    if psd_func is psd_array_welch:
        log = log.getvalue()
        n_fft = kwargs["n_fft"]
        assert f"{n_fft}-point FFT on {n_fft} samples with 0 overl" in log
        assert "hann window" in log
    assert psds.shape == (data.shape[0], len(freqs))
    assert np.sum(freqs < 0) == 0
    assert np.sum(psds < 0) == 0
    # Is power found where it should be? zip pairs the leading channels (the
    # ones carrying sinusoids) with their injected frequency.
    ixs_max = np.argmax(psds, axis=1)
    for ixmax, ifreq in zip(ixs_max, sinusoid_freqs):
        # index of the frequency bin nearest to the "true" freq
        ixtrue = np.argmin(np.abs(ifreq - freqs))
        assert np.abs(ixmax - ixtrue) < 2


def test_psd_array_welch_nperseg_kwarg():
    """Test n_per_seg, padding and argument validation in psd_array_welch()."""
    data, sfreq, _ = _make_psd_data()
    kwargs = dict(fmin=2, fmax=70, n_per_seg=128)
    # With n_per_seg fixed at 128, n_fft=256 pads each segment and halves the
    # frequency spacing, so about twice as many bins fall in [fmin, fmax].
    psds1, freqs1 = psd_array_welch(data, sfreq, n_fft=128, **kwargs)
    psds2, freqs2 = psd_array_welch(data, sfreq, n_fft=256, **kwargs)
    assert len(freqs1) == len(freqs2) // 2
    assert psds1.shape[-1] == psds2.shape[-1] // 2
    # Each invalid configuration is built *before* entering pytest.raises, so
    # that only the psd_array_welch call itself can satisfy the expectation.
    # bad n_fft: longer than the data
    bad_n_fft = int(data.shape[-1] * 1.1)
    no_seg_kwargs = dict(kwargs, n_per_seg=None)
    with pytest.raises(ValueError, match="n_fft is not allowed to be > n_tim"):
        psd_array_welch(data, sfreq, n_fft=bad_n_fft, **no_seg_kwargs)
    # bad n_overlap: larger than n_per_seg
    short_seg_kwargs = dict(kwargs, n_per_seg=64)
    with pytest.raises(ValueError, match="n_overlap cannot be greater"):
        psd_array_welch(data, sfreq, n_fft=128, n_overlap=90, **short_seg_kwargs)
    # bad fmin/fmax: empty frequency range
    with pytest.raises(ValueError, match="No frequencies found"):
        psd_array_welch(data, sfreq, fmin=10, fmax=1)


def test_complex_multitaper():
    """Test that complex multitaper output reproduces the power output."""
    data, sfreq, _ = _make_psd_data()
    snippet = data[:4, :500]
    spectra, freqs_complex, weights = psd_array_multitaper(
        snippet, sfreq, output="complex"
    )
    power, freqs_power = psd_array_multitaper(snippet, sfreq, output="power")
    assert_array_equal(freqs_complex, freqs_power)
    # complex output keeps the taper axis: channels x tapers x freqs
    assert spectra.ndim == 3
    # combining the tapered spectra with the returned weights gives the power
    assert_allclose(_psd_from_mt(spectra, weights), power)


# Copied from SciPy (scipy.signal's private ``_median_bias`` helper)
def _median_bias(n):
    """Return the bias of the median of ``n`` periodogram estimates.

    This is ``1 + sum_{k=1}^{(n-1)//2} (1/(2k+1) - 1/(2k))``, which equals 1
    for ``n <= 2``.
    """
    evens = 2 * np.arange(1.0, (n - 1) // 2 + 1)
    terms = 1.0 / (evens + 1) - 1.0 / evens
    return 1 + np.sum(terms)


@pytest.mark.parametrize("crop", (False, True))
def test_psd_array_welch_average_kwarg(crop):
    """Test the ``average`` kwarg of psd_array_welch().

    The mean and median aggregates must agree with aggregating the
    un-aggregated (``average=None``) per-segment estimates by hand. The
    un-aggregated output must have one trailing-axis entry per segment.
    """
    data, sfreq, _ = _make_psd_data()
    n_per_seg = 32
    kwargs = dict(fmin=0, fmax=np.inf, n_fft=64, n_per_seg=n_per_seg, n_overlap=0)
    # optionally crop data by n_per_seg so that we are sure to test both an
    # odd and an even number of segment estimates (relevant for median bias)
    if crop:
        data = data[..., :-n_per_seg]
    # run with average=mean/median/None
    psds_mean, freqs_mean = psd_array_welch(data, sfreq, average="mean", **kwargs)
    psds_median, freqs_median = psd_array_welch(data, sfreq, average="median", **kwargs)
    psds_unagg, freqs_unagg = psd_array_welch(data, sfreq, average=None, **kwargs)
    # identical input data -> identical frequencies for every "average" type
    assert_array_equal(freqs_mean, freqs_median)
    assert_array_equal(freqs_mean, freqs_unagg)
    # for average=None, the last dimension holds the un-aggregated segments
    assert psds_mean.shape == psds_median.shape
    assert psds_mean.shape == psds_unagg.shape[:-1]
    assert_array_equal(psds_mean, psds_unagg.mean(axis=-1))
    # compare with a manual, SciPy-style bias-corrected median
    bias = _median_bias(psds_unagg.shape[-1])
    assert_allclose(psds_median, np.median(psds_unagg, axis=-1) / bias)
    # Check the shape of the un-aggregated output. With n_overlap=0, Welch
    # segmentation (as in scipy.signal.welch) only uses complete segments, so
    # the expected count is the floor, not the ceiling, of n_times / n_per_seg.
    n_chan, n_times = data.shape
    n_freq = len(freqs_unagg)
    n_segs = n_times // n_per_seg
    assert n_segs % 2 == int(crop)
    assert psds_unagg.shape == (n_chan, n_freq, n_segs)


@pytest.mark.parametrize("n", (2, 3, 5, 8, 12, 13, 14, 15))
def test_median_biases(n):
    """Test that vectorized _median_biases matches the scalar _median_bias."""
    # index k of the vectorized result corresponds to k estimates
    scalar_biases = [_median_bias(count) for count in range(2, n + 1)]
    want_biases = np.array([1.0, 1.0, *scalar_biases])
    got_biases = _median_biases(n)
    assert_allclose(want_biases, got_biases)
    assert_allclose(got_biases[n], _median_bias(n))
    # 0, 1 and 2 estimates all have a bias of exactly one
    assert_allclose(got_biases[:3], 1.0)


@pytest.mark.slowtest
def test_compares_psd():
    """Test that psd_array_welch matches scipy.signal.welch on array data.

    The SciPy call mirrors psd_array_welch's defaults explicitly: a Hamming
    window, no overlap, and a segment length equal to ``n_fft``.
    """
    data, sfreq, _ = _make_psd_data()
    n_fft = 2048
    fmin, fmax = 2, 70
    # Compute PSD with psd_array_welch
    psds_mne, freqs_mne = psd_array_welch(
        data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft
    )
    # Compute psds with scipy.signal.welch
    freqs_scipy, psds_scipy = welch(
        data, fs=sfreq, nperseg=n_fft, noverlap=0, window="hamming"
    )
    # SciPy returns the full spectrum; restrict it to [fmin, fmax]
    mask = (freqs_scipy >= fmin) & (freqs_scipy <= fmax)
    freqs_scipy = freqs_scipy[mask]
    psds_scipy = psds_scipy[:, mask]
    # make sure they match
    assert_array_almost_equal(psds_mne, psds_scipy)
    assert_array_equal(freqs_mne, freqs_scipy)
    assert psds_mne.shape == (data.shape[0], len(freqs_mne))
    assert psds_scipy.shape == (data.shape[0], len(freqs_scipy))
    assert np.sum(freqs_mne < 0) == 0
    assert np.sum(freqs_scipy < 0) == 0
    assert np.sum(psds_mne < 0) == 0
    assert np.sum(psds_scipy < 0) == 0


def test_psd_array_welch_n_jobs():
    """Test psd_array_welch with n_jobs equal to and above the channel count."""
    single_channel = np.zeros((1, 2048))
    # a single channel with 2 jobs must not break the parallel dispatch
    for n_jobs in (1, 2):
        psd_array_welch(single_channel, 1024, n_jobs=n_jobs)


def test_psd_nan_in_data():
    """Test that +Inf inside the analyzed samples only invalidates its channel.

    psd_array_welch should warn about non-finite values and return an all-NaN
    PSD for the contaminated channel. The clean channel should still yield
    finite values.
    """
    n_samples, n_fft, n_overlap = 2048, 256, 128
    rng = np.random.RandomState(0)
    x = rng.standard_normal(size=(2, n_samples))
    # Put +Inf inside the series so that it falls within the Welch windows
    x[0, 800] = np.inf  # channel 0 is contaminated, channel 1 stays clean
    with pytest.warns(RuntimeWarning, match="Non-finite values"):
        psds, _ = psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)

    # Channel 0 is contaminated -> NaN PSD
    assert np.isnan(psds[0]).all()

    # Channel 1 is clean -> has finite PSD values
    assert np.isfinite(psds[1]).any()


def test_psd_misaligned_nan_across_channels():
    """Test NaNs whose positions are not aligned across channels.

    Only channel 0 contains a NaN, so the NaN masks differ between channels.
    psd_array_welch should warn about non-finite values and return an all-NaN
    PSD for the affected channel. The clean channel should keep finite values.
    """
    n_samples, n_fft, n_overlap = 2048, 256, 128
    rng = np.random.RandomState(42)
    x = rng.standard_normal(size=(2, n_samples))
    # NaN only in ch0; ch1 has no NaN => masks not aligned -> should warn
    x[0, 500] = np.nan
    with pytest.warns(RuntimeWarning, match="Non-finite values"):
        psds, _ = psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)
    # Bad channel gets NaN PSD
    assert np.isnan(psds[0]).all()

    # Good channel retains finite values
    assert np.isfinite(psds[1]).any()
